{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.9.0\n",
      "\n",
      "torch 1.3.1\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# LeNet-5 MNIST Digits Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook implements the classic LeNet-5 convolutional network [1] and applies it to MNIST digit classification. The basic architecture is shown in the figure below:\n",
    "\n",
    "![](../images/lenet/lenet-5_1.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LeNet-5 is commonly regarded as a pioneering convolutional neural network, with a very simple architecture by modern standards. In total, LeNet-5 consists of only 7 layers. Three of these are convolutional layers (C1, C3, C5), which are connected by two average pooling layers (S2 & S4). The penultimate layer is a fully connected layer (F6), which is followed by the final output layer. Additional details are summarized below:\n",
    "\n",
    "- All convolutional layers use 5x5 kernels with stride 1.\n",
    "- The two average pooling (subsampling) layers are 2x2 pixels wide with stride 2. (**In this notebook, we use max pooling instead**)\n",
    "- Throughout the network, tanh activation functions are used. (**This notebook keeps the tanh activations**)\n",
    "- The output layer uses 10 custom Euclidean Radial Basis Function neurons. (**In this notebook, we replace these with softmax activations**)\n",
    "- The input size is 32x32; here, we rescale the MNIST images from 28x28 to 32x32 to match this input dimension. Alternatively, we would have to change the network architecture to accept 28x28 inputs.\n",
    "\n",
    "LeNet-5 was able to achieve an error rate below 1% on the MNIST data set, which was very close to the state of the art at the time (produced by a boosted ensemble of three LeNet-4 networks).\n",
    "\n",
    "\n",
    "### References\n",
    "\n",
    "- [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, November 1998."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
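   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.001\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 10\n",
    "\n",
    "# Architecture\n",
    "NUM_FEATURES = 32*32\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "# Other\n",
    "# Use the second GPU if present; otherwise fall back to the first GPU or the CPU\n",
    "if torch.cuda.device_count() > 1:\n",
    "    DEVICE = \"cuda:1\"\n",
    "elif torch.cuda.is_available():\n",
    "    DEVICE = \"cuda:0\"\n",
    "else:\n",
    "    DEVICE = \"cpu\"\n",
    "GRAYSCALE = True"
   ]
  },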
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([128, 1, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "resize_transform = transforms.Compose([transforms.Resize((32, 32)),\n",
    "                                       transforms.ToTensor()])\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=resize_transform,\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=resize_transform)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
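  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the loaders above (a worked calculation, assuming the standard MNIST training split of 60,000 images and the default `drop_last=False` of `DataLoader`), the number of minibatches per epoch is\n",
    "\n",
    "$$\\left\\lceil \\frac{60000}{128} \\right\\rceil = 469,$$\n",
    "\n",
    "where the first 468 batches contain 128 images each and the last one contains the remaining $60000 - 468 \\cdot 128 = 96$ images. The batch counter printed during training should therefore run up to 469."
   ]
  },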
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(DEVICE)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "for epoch in range(2):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class LeNet5(nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes, grayscale=False):\n",
    "        super(LeNet5, self).__init__()\n",
    "        \n",
    "        self.grayscale = grayscale\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "        if self.grayscale:\n",
    "            in_channels = 1\n",
    "        else:\n",
    "            in_channels = 3\n",
    "\n",
    "        self.features = nn.Sequential(\n",
    "            \n",
    "            nn.Conv2d(in_channels, 6, kernel_size=5),\n",
    "            nn.Tanh(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            nn.Conv2d(6, 16, kernel_size=5),\n",
    "            nn.Tanh(),\n",
    "            nn.MaxPool2d(kernel_size=2)\n",
    "        )\n",
    "\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(16*5*5, 120),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(120, 84),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(84, num_classes),\n",
    "        )\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        logits = self.classifier(x)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas"
   ]
  },
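  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input size `16*5*5` of the first fully connected layer in the model above can be checked by hand. As a sketch, assuming unpadded convolutions and PyTorch's default pooling stride (equal to the kernel size), a layer with kernel size $k$ and stride $s$ maps an $n \\times n$ input to an $n_{\\text{out}} \\times n_{\\text{out}}$ output with\n",
    "\n",
    "$$n_{\\text{out}} = \\left\\lfloor \\frac{n - k}{s} \\right\\rfloor + 1.$$\n",
    "\n",
    "Applied to the four layers of `self.features`, starting from the 32x32 input:\n",
    "\n",
    "$$32 \\xrightarrow{\\text{conv } 5 \\times 5} 28 \\xrightarrow{\\text{pool } 2 \\times 2} 14 \\xrightarrow{\\text{conv } 5 \\times 5} 10 \\xrightarrow{\\text{pool } 2 \\times 2} 5$$\n",
    "\n",
    "With 16 output channels, flattening yields $16 \\cdot 5 \\cdot 5 = 400$ features."
   ]
  },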
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = LeNet5(NUM_CLASSES, GRAYSCALE)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
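  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the number of trainable parameters of this model can be tallied by hand. The calculation below is a sketch for the grayscale configuration used here (one input channel), counting one bias per output channel or unit; the layer labels are shorthand introduced only for this tally.\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\text{conv1:} \\quad & (1 \\cdot 5 \\cdot 5 + 1) \\cdot 6 = 156 \\\\\n",
    "\\text{conv2:} \\quad & (6 \\cdot 5 \\cdot 5 + 1) \\cdot 16 = 2416 \\\\\n",
    "\\text{fc1:} \\quad & (400 + 1) \\cdot 120 = 48120 \\\\\n",
    "\\text{fc2:} \\quad & (120 + 1) \\cdot 84 = 10164 \\\\\n",
    "\\text{fc3:} \\quad & (84 + 1) \\cdot 10 = 850 \\\\\n",
    "\\text{total:} \\quad & 156 + 2416 + 48120 + 10164 + 850 = 61706\n",
    "\\end{aligned}$$\n",
    "\n",
    "The same total should be returned by `sum(p.numel() for p in model.parameters())`."
   ]
  },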
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
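  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `forward` method returns both the logits $z$ and the softmax probabilities, but the training loop passes the logits to `F.cross_entropy`, which applies the log-softmax internally. For a single example with true class $y$ and $K = 10$ classes, the quantities involved are (standard definitions, restated here for convenience):\n",
    "\n",
    "$$p_k = \\frac{\\exp(z_k)}{\\sum_{j=1}^{K} \\exp(z_j)}, \\qquad \\ell = -\\log p_y = -z_y + \\log \\sum_{j=1}^{K} \\exp(z_j).$$\n",
    "\n",
    "The reported cost is the mean of $\\ell$ over the minibatch (the default `reduction='mean'`). Passing `probas` instead of `logits` would apply the softmax twice. With near-uniform predictions at initialization, $\\ell \\approx \\log 10 \\approx 2.303$, which is consistent with the first logged cost."
   ]
  },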
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 0000/0469 | Cost: 2.3055\n",
      "Epoch: 001/010 | Batch 0050/0469 | Cost: 0.5465\n",
      "Epoch: 001/010 | Batch 0100/0469 | Cost: 0.3708\n",
      "Epoch: 001/010 | Batch 0150/0469 | Cost: 0.3407\n",
      "Epoch: 001/010 | Batch 0200/0469 | Cost: 0.1298\n",
      "Epoch: 001/010 | Batch 0250/0469 | Cost: 0.1856\n",
      "Epoch: 001/010 | Batch 0300/0469 | Cost: 0.0940\n",
      "Epoch: 001/010 | Batch 0350/0469 | Cost: 0.1851\n",
      "Epoch: 001/010 | Batch 0400/0469 | Cost: 0.1425\n",
      "Epoch: 001/010 | Batch 0450/0469 | Cost: 0.0623\n",
      "Epoch: 001/010 | Train: 96.658%\n",
      "Time elapsed: 0.30 min\n",
      "Epoch: 002/010 | Batch 0000/0469 | Cost: 0.0659\n",
      "Epoch: 002/010 | Batch 0050/0469 | Cost: 0.1018\n",
      "Epoch: 002/010 | Batch 0100/0469 | Cost: 0.0810\n",
      "Epoch: 002/010 | Batch 0150/0469 | Cost: 0.1708\n",
      "Epoch: 002/010 | Batch 0200/0469 | Cost: 0.0639\n",
      "Epoch: 002/010 | Batch 0250/0469 | Cost: 0.0769\n",
      "Epoch: 002/010 | Batch 0300/0469 | Cost: 0.0425\n",
      "Epoch: 002/010 | Batch 0350/0469 | Cost: 0.0942\n",
      "Epoch: 002/010 | Batch 0400/0469 | Cost: 0.0303\n",
      "Epoch: 002/010 | Batch 0450/0469 | Cost: 0.0688\n",
      "Epoch: 002/010 | Train: 98.223%\n",
      "Time elapsed: 0.60 min\n",
      "Epoch: 003/010 | Batch 0000/0469 | Cost: 0.0867\n",
      "Epoch: 003/010 | Batch 0050/0469 | Cost: 0.0323\n",
      "Epoch: 003/010 | Batch 0100/0469 | Cost: 0.0311\n",
      "Epoch: 003/010 | Batch 0150/0469 | Cost: 0.0590\n",
      "Epoch: 003/010 | Batch 0200/0469 | Cost: 0.0507\n",
      "Epoch: 003/010 | Batch 0250/0469 | Cost: 0.0484\n",
      "Epoch: 003/010 | Batch 0300/0469 | Cost: 0.0492\n",
      "Epoch: 003/010 | Batch 0350/0469 | Cost: 0.1143\n",
      "Epoch: 003/010 | Batch 0400/0469 | Cost: 0.0164\n",
      "Epoch: 003/010 | Batch 0450/0469 | Cost: 0.0303\n",
      "Epoch: 003/010 | Train: 98.735%\n",
      "Time elapsed: 0.90 min\n",
      "Epoch: 004/010 | Batch 0000/0469 | Cost: 0.1143\n",
      "Epoch: 004/010 | Batch 0050/0469 | Cost: 0.0239\n",
      "Epoch: 004/010 | Batch 0100/0469 | Cost: 0.0171\n",
      "Epoch: 004/010 | Batch 0150/0469 | Cost: 0.0102\n",
      "Epoch: 004/010 | Batch 0200/0469 | Cost: 0.0484\n",
      "Epoch: 004/010 | Batch 0250/0469 | Cost: 0.0436\n",
      "Epoch: 004/010 | Batch 0300/0469 | Cost: 0.0156\n",
      "Epoch: 004/010 | Batch 0350/0469 | Cost: 0.0610\n",
      "Epoch: 004/010 | Batch 0400/0469 | Cost: 0.0331\n",
      "Epoch: 004/010 | Batch 0450/0469 | Cost: 0.1403\n",
      "Epoch: 004/010 | Train: 98.898%\n",
      "Time elapsed: 1.20 min\n",
      "Epoch: 005/010 | Batch 0000/0469 | Cost: 0.0127\n",
      "Epoch: 005/010 | Batch 0050/0469 | Cost: 0.0542\n",
      "Epoch: 005/010 | Batch 0100/0469 | Cost: 0.0505\n",
      "Epoch: 005/010 | Batch 0150/0469 | Cost: 0.0051\n",
      "Epoch: 005/010 | Batch 0200/0469 | Cost: 0.0455\n",
      "Epoch: 005/010 | Batch 0250/0469 | Cost: 0.0048\n",
      "Epoch: 005/010 | Batch 0300/0469 | Cost: 0.0517\n",
      "Epoch: 005/010 | Batch 0350/0469 | Cost: 0.0703\n",
      "Epoch: 005/010 | Batch 0400/0469 | Cost: 0.0785\n",
      "Epoch: 005/010 | Batch 0450/0469 | Cost: 0.0187\n",
      "Epoch: 005/010 | Train: 99.377%\n",
      "Time elapsed: 1.49 min\n",
      "Epoch: 006/010 | Batch 0000/0469 | Cost: 0.0363\n",
      "Epoch: 006/010 | Batch 0050/0469 | Cost: 0.0069\n",
      "Epoch: 006/010 | Batch 0100/0469 | Cost: 0.0156\n",
      "Epoch: 006/010 | Batch 0150/0469 | Cost: 0.0714\n",
      "Epoch: 006/010 | Batch 0200/0469 | Cost: 0.0099\n",
      "Epoch: 006/010 | Batch 0250/0469 | Cost: 0.0362\n",
      "Epoch: 006/010 | Batch 0300/0469 | Cost: 0.0044\n",
      "Epoch: 006/010 | Batch 0350/0469 | Cost: 0.0232\n",
      "Epoch: 006/010 | Batch 0400/0469 | Cost: 0.0093\n",
      "Epoch: 006/010 | Batch 0450/0469 | Cost: 0.0922\n",
      "Epoch: 006/010 | Train: 99.440%\n",
      "Time elapsed: 1.79 min\n",
      "Epoch: 007/010 | Batch 0000/0469 | Cost: 0.0095\n",
      "Epoch: 007/010 | Batch 0050/0469 | Cost: 0.0074\n",
      "Epoch: 007/010 | Batch 0100/0469 | Cost: 0.0051\n",
      "Epoch: 007/010 | Batch 0150/0469 | Cost: 0.0113\n",
      "Epoch: 007/010 | Batch 0200/0469 | Cost: 0.0274\n",
      "Epoch: 007/010 | Batch 0250/0469 | Cost: 0.0227\n",
      "Epoch: 007/010 | Batch 0300/0469 | Cost: 0.0294\n",
      "Epoch: 007/010 | Batch 0350/0469 | Cost: 0.0069\n",
      "Epoch: 007/010 | Batch 0400/0469 | Cost: 0.0236\n",
      "Epoch: 007/010 | Batch 0450/0469 | Cost: 0.0031\n",
      "Epoch: 007/010 | Train: 99.573%\n",
      "Time elapsed: 2.09 min\n",
      "Epoch: 008/010 | Batch 0000/0469 | Cost: 0.0031\n",
      "Epoch: 008/010 | Batch 0050/0469 | Cost: 0.0054\n",
      "Epoch: 008/010 | Batch 0100/0469 | Cost: 0.0035\n",
      "Epoch: 008/010 | Batch 0150/0469 | Cost: 0.0074\n",
      "Epoch: 008/010 | Batch 0200/0469 | Cost: 0.0306\n",
      "Epoch: 008/010 | Batch 0250/0469 | Cost: 0.0180\n",
      "Epoch: 008/010 | Batch 0300/0469 | Cost: 0.0049\n",
      "Epoch: 008/010 | Batch 0350/0469 | Cost: 0.0099\n",
      "Epoch: 008/010 | Batch 0400/0469 | Cost: 0.0173\n",
      "Epoch: 008/010 | Batch 0450/0469 | Cost: 0.0096\n",
      "Epoch: 008/010 | Train: 99.738%\n",
      "Time elapsed: 2.38 min\n",
      "Epoch: 009/010 | Batch 0000/0469 | Cost: 0.0015\n",
      "Epoch: 009/010 | Batch 0050/0469 | Cost: 0.0045\n",
      "Epoch: 009/010 | Batch 0100/0469 | Cost: 0.0078\n",
      "Epoch: 009/010 | Batch 0150/0469 | Cost: 0.0007\n",
      "Epoch: 009/010 | Batch 0200/0469 | Cost: 0.0129\n",
      "Epoch: 009/010 | Batch 0250/0469 | Cost: 0.0139\n",
      "Epoch: 009/010 | Batch 0300/0469 | Cost: 0.0031\n",
      "Epoch: 009/010 | Batch 0350/0469 | Cost: 0.0044\n",
      "Epoch: 009/010 | Batch 0400/0469 | Cost: 0.0066\n",
      "Epoch: 009/010 | Batch 0450/0469 | Cost: 0.0011\n",
      "Epoch: 009/010 | Train: 99.795%\n",
      "Time elapsed: 2.68 min\n",
      "Epoch: 010/010 | Batch 0000/0469 | Cost: 0.0046\n",
      "Epoch: 010/010 | Batch 0050/0469 | Cost: 0.0334\n",
      "Epoch: 010/010 | Batch 0100/0469 | Cost: 0.0059\n",
      "Epoch: 010/010 | Batch 0150/0469 | Cost: 0.0201\n",
      "Epoch: 010/010 | Batch 0200/0469 | Cost: 0.0132\n",
      "Epoch: 010/010 | Batch 0250/0469 | Cost: 0.0234\n",
      "Epoch: 010/010 | Batch 0300/0469 | Cost: 0.0085\n",
      "Epoch: 010/010 | Batch 0350/0469 | Cost: 0.0073\n",
      "Epoch: 010/010 | Batch 0400/0469 | Cost: 0.0029\n",
      "Epoch: 010/010 | Batch 0450/0469 | Cost: 0.0112\n",
      "Epoch: 010/010 | Train: 99.840%\n",
      "Time elapsed: 2.97 min\n",
      "Total Training Time: 2.97 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    \"\"\"Return the accuracy (in percent) of `model` on `data_loader`.\"\"\"\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for features, targets in data_loader:\n",
    "\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "\n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        cost.backward()\n",
    "\n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "\n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'\n",
    "                  % (epoch+1, NUM_EPOCHS, batch_idx,\n",
    "                     len(train_loader), cost.item()))\n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False):  # no gradients needed for evaluation; saves memory\n",
    "        print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
    "              epoch+1, NUM_EPOCHS,\n",
    "              compute_accuracy(model, train_loader, device=DEVICE)))\n",
    "\n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "\n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 98.93%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False): # save memory during inference\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQZUlEQVR4nO3dfYxUZZbH8e+h7VZAAYEWO7xsI6KOmaygHYKCE3R2lSWTIMlINMEoMcNkMyarGf/wJUE32Rhns2r4Y+MGFxzcIOqOGonRHV/ihhiMQ+MiL+IqIusgpF9EAxoU6T77R10yLdZT3V1Vt6q7z++TdLr6OXX7nlz49a26T917zd0RkZFvVL0bEJHaUNhFglDYRYJQ2EWCUNhFglDYRYI4o5KFzWwxsAZoAP7d3R8u9fzJkyd7a2trJasUkRIOHDhAd3e3FauVHXYzawD+Ffhb4CCwzcw2u/sHqWVaW1tpb28vd5Ui0o+2trZkrZKX8fOAfe6+391PAM8ASyv4fSKSo0rCPhX4c5+fD2ZjIjIEVRL2Yu8LfvTZWzNbZWbtZtbe1dVVwepEpBKVhP0gML3Pz9OAQ6c/yd3Xunubu7c1NzdXsDoRqUQlYd8GzDazmWbWBNwEbK5OWyJSbWUfjXf3k2Z2B/BHClNv6919T9U6E5Gqqmie3d1fAV6pUi8ikiN9gk4kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kiIruCGNmB4BjQA9w0t3Td4IXkbqqKOyZa9y9uwq/R0RypJfxIkFUGnYHXjOz7Wa2qhoNiUg+Kn0Zv8DdD5nZecDrZvahu2/p+4Tsj8AqgBkzZlS4OhEpV0V7dnc/lH3vBF4E5hV5zlp3b3P3tubm5kpWJyIVKDvsZjbWzM459Ri4DthdrcZEpLoqeRk/BXjRzE79nqfd/b+q0pWIVF3ZYXf3/cBlVexFRHKkqTeRIBR2kSAUdpEgFHaRIBR2kSCqcSLMiOLuyVpvb2/R8Z6enuQy2dTkoI0alf47XOp3pmrl9iEjh/bsIkEo7CJBKOwiQSjsIkEo7CJB6Gj8ab777rtkbceOHUXHn3rqqeQy48aNS9bGjh2brC1cuDBZu/jii5O18ePHD3pdEoP27CJBKOwiQSjsIkEo7CJBKOwiQSjsIkFo6u0033zzTbK2evXqouPbt29PLlPqBJSGhoZk7cknn0zWJk2alKylLtc9ki/jfcYZ6f/GU6dOLTp+4403JpeZMmVKWesa6rRnFwlCYRcJQmEXCUJhFwlCYRcJQmEXCaLfeQQzWw/8Auh0959mYxOBZ4FW4ACw3N2/zK/N2jnrrLOStRUrVhQdnzt3bnKZUtM4X331VbL22WefJWsffvhhsrZ169ZBjQNMnDgxWevu7k7WSl17L6XUdGOpbV9quVLTpRMmTBjUOMDy5cuTtZE+9fZ7YPFpY/cAb7r7bODN7GcRGcL6DXt2v/Ujpw0vBTZkjzcAN1S5LxGpsnLfs09x98MA2ffzqteSiOQh9wN0ZrbKzNrNrL2rqyvv1YlIQrlh7zCzFoDse2fqie6+1t3b3L2tubm5zNWJSKXKDftm4Nbs8a3AS9VpR0TyMpCpt03AImCymR0EHgAeBp4zs9uBz4D0KUTDTKnpn6VLlxYdv/baa5PLjBkzJlk7ceJEsnbs2LFkraOjI1n76KOPio53diZffDF79uxkbffu3clatafeSp3N9/nnnydra9asSdZS2/Hrr79OLlPqFmDDWb9hd/ebE6WfV7kXEcmRPkEnEoTCLhKEwi4ShMIuEoTCLhLE8D2FJyelpoZS91FLjVfi/PPPT9ZmzZqVrF1xxRVFx0vdw67UfeCuueaaZK23tzdZSxk1
Kr1/KTWVt2XLlmSt1JloqQ9yXXnllWX9vuFMe3aRIBR2kSAUdpEgFHaRIBR2kSAUdpEgRuYcwwhXavpq9OjRgxrvT6mLUZaj1HTd/v37k7U33ngjWWtsbEzWUmcqXnLJJcllmpqakrXhTHt2kSAUdpEgFHaRIBR2kSAUdpEgdDReaur48ePJ2muvvZasbdy4MVmbNm1asrZy5cqi46WuNWhmydpwpj27SBAKu0gQCrtIEAq7SBAKu0gQCrtIEAO5/dN64BdAp7v/NBt7EPgVcOq2rPe5+yt5NSnDT+qEl7179yaXeeutt5K1kydPJmulbl/V0tJSdHykTq+VMpA9+++BxUXGH3P3OdmXgi4yxPUbdnffAhypQS8ikqNK3rPfYWY7zWy9mZ1btY5EJBflhv1xYBYwBzgMPJJ6opmtMrN2M2vv6upKPU1EclZW2N29w9173L0XeAKYV+K5a929zd3bUhfsF5H8lRV2M+t7iHMZsLs67YhIXgYy9bYJWARMNrODwAPAIjObAzhwAPh1jj3KMHTkSPFjuk8//XRymZdffjlZW7RoUbL2yCPJd5HJa+iVuo7fSNVv2N395iLD63LoRURyFO/Pm0hQCrtIEAq7SBAKu0gQCrtIELrgpOTik08+KTr+8ccfJ5eZNGlSsrZw4cJkbfr06claxCm2FG0JkSAUdpEgFHaRIBR2kSAUdpEgFHaRIDT1JmXr6elJ1rZv3150fN++fcllrr766mRt2bJlyVpTU1OyJn+hPbtIEAq7SBAKu0gQCrtIEAq7SBA6Gi9l+/TTT5O1t99+u+h46rZQAFdddVWyNnPmzGQt4q2cyqE9u0gQCrtIEAq7SBAKu0gQCrtIEAq7SBADuf3TdOAp4HygF1jr7mvMbCLwLNBK4RZQy939y/xalby4e7J29OjRZG3duvSNgbZu3Vp0vNTJLqVu8TR69OhkTQZmIHv2k8Bv3f0nwHzgN2Z2KXAP8Ka7zwbezH4WkSGq37C7+2F3fy97fAzYC0wFlgIbsqdtAG7Iq0kRqdyg3rObWSswF3gXmOLuh6HwBwE4r9rNiUj1DDjsZnY28Dxwp7un38j9eLlVZtZuZu1dXV3l9CgiVTCgsJtZI4Wgb3T3F7LhDjNryeotQGexZd19rbu3uXtbc3NzNXoWkTL0G3YrnGWwDtjr7o/2KW0Gbs0e3wq8VP32RKRaBnLW2wLgFmCXme3Ixu4DHgaeM7Pbgc+AG/NpUaqh1PTaiRMnkrVXX301Wdu0adOg13fdddcll7nggguSNalcv2F397eB1DmEP69uOyKSF32CTiQIhV0kCIVdJAiFXSQIhV0kCF1wMojvv/8+WSt14ciHHnooWevu7k7Wli9fXnR8/vz5yWXGjh2brEnltGcXCUJhFwlCYRcJQmEXCUJhFwlCYRcJQlNvI0zqbLMjR44kl7ntttuStQ8++CBZu+iii5K1lStXFh2fMWNGchnJl/bsIkEo7CJBKOwiQSjsIkEo7CJB6Gj8CJM64eXQoUPJZdrb25O1np6eZO3ee+9N1i677LKi401NTcllJF/as4sEobCLBKGwiwShsIsEobCLBKGwiwTR79SbmU0HngLOB3qBte6+xsweBH4FnLo1633u/kpejcpflLpd0549e4qO33333cllGhsbk7X7778/WVu8eHGylrqeXOHWgVIPA5lnPwn81t3fM7NzgO1m9npWe8zd/yW/9kSkWgZyr7fDwOHs8TEz2wtMzbsxEamuQb1nN7NWYC7wbjZ0h5ntNLP1ZnZulXsTkSoacNjN7GzgeeBOdz8KPA7MAuZQ2PM/klhulZm1m1l7V1dXsaeISA0MKOxm1kgh6Bvd/QUAd+9w9x537wWeAOYVW9bd17p7m7u3NTc3V6tvERmkfsNuhcOn64C97v5on/GWPk9bBuyufnsiUi0DORq/ALgF2GVmO7Kx+4CbzWwO4MAB4Ne5dCg/Uup6cps3by46/s477ySXKTX1tmTJ
kmRt3LhxyVpDQ0OyJvUxkKPxbwPFJkc1py4yjOgTdCJBKOwiQSjsIkEo7CJBKOwiQeiCk0PU8ePHk7Vt27Ylay+99FLR8W+//Ta5zJlnnpmsjR8/PlkbNUr7iuFE/1oiQSjsIkEo7CJBKOwiQSjsIkEo7CJBaOptiPriiy+StS1btiRru3btKjpe6iy0CRMmJGulzoiT4UV7dpEgFHaRIBR2kSAUdpEgFHaRIBR2kSA09TZEHT16NFnr6OhI1lL3UpsxY0ZymRUrViRr556bvveHznobXvSvJRKEwi4ShMIuEoTCLhKEwi4SRL9H483sLGALcGb2/D+4+wNmNhN4BpgIvAfc4u4n8mw2kjFjxiRrF154YbJ2/fXXFx1fsGBBcpm77rorWWtqakrWUkf+ZWgayJ79O+Bad7+Mwu2ZF5vZfOB3wGPuPhv4Erg9vzZFpFL9ht0Lvs5+bMy+HLgW+EM2vgG4IZcORaQqBnp/9obsDq6dwOvAJ8BX7n4ye8pBYGo+LYpINQwo7O7e4+5zgGnAPOAnxZ5WbFkzW2Vm7WbW3tXVVX6nIlKRQR2Nd/evgP8G5gMTzOzUAb5pwKHEMmvdvc3d25qbmyvpVUQq0G/YzazZzCZkj0cDfwPsBd4Cfpk97Vag+K1IRGRIGMiJMC3ABjNroPDH4Tl3f9nMPgCeMbN/Av4HWJdjn+G0trYma6tXr65dIzJi9Bt2d98JzC0yvp/C+3cRGQb0CTqRIBR2kSAUdpEgFHaRIBR2kSDMvegH3/JZmVkX8H/Zj5OB7pqtPE19/JD6+KHh1sdfuXvRT6/VNOw/WLFZu7u31WXl6kN9BOxDL+NFglDYRYKoZ9jX1nHdfamPH1IfPzRi+qjbe3YRqS29jBcJoi5hN7PFZva/ZrbPzO6pRw9ZHwfMbJeZ7TCz9hqud72ZdZrZ7j5jE83sdTP7OPuevu9Svn08aGafZ9tkh5ktqUEf083sLTPba2Z7zOwfsvGabpMSfdR0m5jZWWb2JzN7P+vjH7PxmWb2brY9njWz9NVAi3H3mn4BDRQua3UB0AS8D1xa6z6yXg4Ak+uw3p8BlwO7+4z9M3BP9vge4Hd16uNB4O4ab48W4PLs8TnAR8Cltd4mJfqo6TYBDDg7e9wIvEvhgjHPATdl4/8G/P1gfm899uzzgH3uvt8Ll55+Blhahz7qxt23AEdOG15K4cKdUKMLeCb6qDl3P+zu72WPj1G4OMpUarxNSvRRU15Q9Yu81iPsU4E/9/m5nherdOA1M9tuZqvq1MMpU9z9MBT+0wHn1bGXO8xsZ/YyP/e3E32ZWSuF6ye8Sx23yWl9QI23SR4Xea1H2IvdWaBeUwIL3P1y4O+A35jZz+rUx1DyODCLwj0CDgOP1GrFZnY28Dxwp7un71ld+z5qvk28gou8ptQj7AeB6X1+Tl6sMm/ufij73gm8SH2vvNNhZi0A2ffOejTh7h3Zf7Re4AlqtE3MrJFCwDa6+wvZcM23SbE+6rVNsnUP+iKvKfUI+zZgdnZksQm4Cdhc6ybMbKyZnXPqMXAdsLv0UrnaTOHCnVDHC3ieCldmGTXYJla4j9Q6YK+7P9qnVNNtkuqj1tskt4u81uoI42lHG5dQONL5CXB/nXq4gMJMwPvAnlr2AWyi8HLwewqvdG4HJgFvAh9n3yfWqY//AHYBOymEraUGfSyk8JJ0J7Aj+1pS621Soo+abhPgrylcxHUnhT8sq/v8n/0TsA/4T+DMwfxefYJOJAh9gk4kCIVdJAiFXSQIhV0kCIVdJAiFXSQIhV0kCIVdJIj/B2WuJureEJR4AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Grab the first batch of the test set\n",
    "features, targets = next(iter(test_loader))\n",
    "\n",
    "# Convert the first image from CHW to HWC layout, then drop the channel axis\n",
    "hwc_img = features[0].permute(1, 2, 0).numpy()\n",
    "hw_img = np.squeeze(hwc_img, axis=2)\n",
    "plt.imshow(hw_img, cmap='Greys');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability 7 100.00%\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():  # no gradients needed for a single prediction\n",
    "    logits, probas = model(features.to(DEVICE)[0, None])\n",
    "print('Probability 7 %.2f%%' % (probas[0][7]*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch       1.3.1\n",
      "numpy       1.17.4\n",
      "matplotlib  3.1.0\n",
      "torchvision 0.4.2\n",
      "pandas      0.24.2\n",
      "PIL.Image   6.2.1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
